{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "Maze_DQN.ipynb",
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "metadata": {
        "id": "1GIz_EbYW5_6",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "a7b11c13-d440-4c66-a629-63ad5983e652"
      },
      "source": [
        "%clear all\n",
        "\n",
        "from __future__ import print_function\n",
        "import os, sys, time, datetime, json, random\n",
        "import numpy as np\n",
        "from keras.models import Sequential\n",
        "from keras.layers.core import Dense, Activation\n",
        "from keras.optimizers import SGD , Adam, RMSprop\n",
        "from keras.layers.advanced_activations import PReLU\n",
        "import matplotlib.pyplot as plt\n",
        "%matplotlib inline\n",
        "\n",
        "# Sample 10x10 maze: 1.0 marks a free cell, 0.0 a blocked cell.\n",
        "# NOTE: `maze` is reassigned in the next cell before it is used there.\n",
        "maze = np.array([\n",
        "    [ 1.,  0.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.],\n",
        "    [ 1.,  1.,  1.,  1.,  1.,  0.,  1.,  1.,  1.,  1.],\n",
        "    [ 1.,  1.,  1.,  1.,  1.,  0.,  1.,  1.,  1.,  1.],\n",
        "    [ 0.,  0.,  1.,  0.,  0.,  1.,  0.,  1.,  1.,  1.],\n",
        "    [ 1.,  1.,  0.,  1.,  0.,  1.,  0.,  0.,  0.,  1.],\n",
        "    [ 1.,  1.,  0.,  1.,  0.,  1.,  1.,  1.,  1.,  1.],\n",
        "    [ 1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.],\n",
        "    [ 1.,  1.,  1.,  1.,  1.,  1.,  0.,  0.,  0.,  0.],\n",
        "    [ 1.,  0.,  0.,  0.,  0.,  0.,  1.,  1.,  1.,  1.],\n",
        "    [ 1.,  1.,  1.,  1.,  1.,  1.,  1.,  0.,  1.,  1.]\n",
        "])\n",
        "\n",
        "# Gray levels used as cell markers\n",
        "visited_mark = 0.8  # Cells visited by the rat will be painted by gray 0.8\n",
        "rat_mark = 0.5      # The current rat cell will be painted by gray 0.5\n",
        "# Action codes; they double as the indices of the network's Q-value outputs\n",
        "LEFT = 0\n",
        "UP = 1\n",
        "RIGHT = 2\n",
        "DOWN = 3\n",
        "\n",
        "# Actions dictionary: action code -> human-readable name\n",
        "actions_dict = {\n",
        "    LEFT: 'left',\n",
        "    UP: 'up',\n",
        "    RIGHT: 'right',\n",
        "    DOWN: 'down',\n",
        "}\n",
        "\n",
        "# Number of actions; used as the width of the model's output layer\n",
        "num_actions = len(actions_dict)\n",
        "\n",
        "# Exploration factor: probability of taking a random valid action while training\n",
        "epsilon = 0.1\n",
        "\n",
        "\n",
        "\n"
      ],
      "execution_count": 14,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "\u001b[H\u001b[2J"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "gGRkPa_VXbAU",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000
        },
        "outputId": "0447b73c-bb0d-460d-eb1c-a602a9e59eee"
      },
      "source": [
        "# maze is a 2d Numpy array of floats between 0.0 to 1.0\n",
        "# 1.0 corresponds to a free cell, and 0.0 an occupied cell\n",
        "# rat = (row, col) initial rat position (defaults to (0,0))\n",
        "\n",
        "class Qmaze(object):\n",
        "    def __init__(self, maze, rat=(0,0)):\n",
        "        self._maze = np.array(maze)\n",
        "        nrows, ncols = self._maze.shape\n",
        "        self.target = (nrows-1, ncols-1)   # target cell where the \"cheese\" is\n",
        "        self.free_cells = [(r,c) for r in range(nrows) for c in range(ncols) if self._maze[r,c] == 1.0]\n",
        "        self.free_cells.remove(self.target)\n",
        "        if self._maze[self.target] == 0.0:\n",
        "            raise Exception(\"Invalid maze: target cell cannot be blocked!\")\n",
        "        if not rat in self.free_cells:\n",
        "            raise Exception(\"Invalid Rat Location: must sit on a free cell\")\n",
        "        self.reset(rat)\n",
        "\n",
        "    def reset(self, rat):\n",
        "        self.rat = rat\n",
        "        self.maze = np.copy(self._maze)\n",
        "        nrows, ncols = self.maze.shape\n",
        "        row, col = rat\n",
        "        self.maze[row, col] = rat_mark\n",
        "        self.state = (row, col, 'start')\n",
        "        self.min_reward = -0.5 * self.maze.size\n",
        "        self.total_reward = 0\n",
        "        self.visited = set()\n",
        "\n",
        "    def update_state(self, action):\n",
        "        nrows, ncols = self.maze.shape\n",
        "        nrow, ncol, nmode = rat_row, rat_col, mode = self.state\n",
        "\n",
        "        if self.maze[rat_row, rat_col] > 0.0:\n",
        "            self.visited.add((rat_row, rat_col))  # mark visited cell\n",
        "\n",
        "        valid_actions = self.valid_actions()\n",
        "                \n",
        "        if not valid_actions:\n",
        "            nmode = 'blocked'\n",
        "        elif action in valid_actions:\n",
        "            nmode = 'valid'\n",
        "            if action == LEFT:\n",
        "                ncol -= 1\n",
        "            elif action == UP:\n",
        "                nrow -= 1\n",
        "            if action == RIGHT:\n",
        "                ncol += 1\n",
        "            elif action == DOWN:\n",
        "                nrow += 1\n",
        "        else:                  # invalid action, no change in rat position\n",
        "            mode = 'invalid'\n",
        "\n",
        "        # new state\n",
        "        self.state = (nrow, ncol, nmode)\n",
        "\n",
        "    def get_reward(self):\n",
        "        rat_row, rat_col, mode = self.state\n",
        "        nrows, ncols = self.maze.shape\n",
        "        if rat_row == nrows-1 and rat_col == ncols-1:\n",
        "            return 1.0\n",
        "        if mode == 'blocked':\n",
        "            return self.min_reward - 1\n",
        "        if (rat_row, rat_col) in self.visited:\n",
        "            return -0.25\n",
        "        if mode == 'invalid':\n",
        "            return -0.75\n",
        "        if mode == 'valid':\n",
        "            return -0.04\n",
        "\n",
        "    def act(self, action):\n",
        "        self.update_state(action)\n",
        "        reward = self.get_reward()\n",
        "        self.total_reward += reward\n",
        "        status = self.game_status()\n",
        "        envstate = self.observe()\n",
        "        return envstate, reward, status\n",
        "\n",
        "    def observe(self):\n",
        "        canvas = self.draw_env()\n",
        "        envstate = canvas.reshape((1, -1))\n",
        "        return envstate\n",
        "\n",
        "    def draw_env(self):\n",
        "        canvas = np.copy(self.maze)\n",
        "        nrows, ncols = self.maze.shape\n",
        "        # clear all visual marks\n",
        "        for r in range(nrows):\n",
        "            for c in range(ncols):\n",
        "                if canvas[r,c] > 0.0:\n",
        "                    canvas[r,c] = 1.0\n",
        "        # draw the rat\n",
        "        row, col, valid = self.state\n",
        "        canvas[row, col] = rat_mark\n",
        "        return canvas\n",
        "\n",
        "    def game_status(self):\n",
        "        if self.total_reward < self.min_reward:\n",
        "            return 'lose'\n",
        "        rat_row, rat_col, mode = self.state\n",
        "        nrows, ncols = self.maze.shape\n",
        "        if rat_row == nrows-1 and rat_col == ncols-1:\n",
        "            return 'win'\n",
        "\n",
        "        return 'not_over'\n",
        "\n",
        "    def valid_actions(self, cell=None):\n",
        "        if cell is None:\n",
        "            row, col, mode = self.state\n",
        "        else:\n",
        "            row, col = cell\n",
        "        actions = [0, 1, 2, 3]\n",
        "        nrows, ncols = self.maze.shape\n",
        "        if row == 0:\n",
        "            actions.remove(1)\n",
        "        elif row == nrows-1:\n",
        "            actions.remove(3)\n",
        "\n",
        "        if col == 0:\n",
        "            actions.remove(0)\n",
        "        elif col == ncols-1:\n",
        "            actions.remove(2)\n",
        "\n",
        "        if row>0 and self.maze[row-1,col] == 0.0:\n",
        "            actions.remove(1)\n",
        "        if row<nrows-1 and self.maze[row+1,col] == 0.0:\n",
        "            actions.remove(3)\n",
        "\n",
        "        if col>0 and self.maze[row,col-1] == 0.0:\n",
        "            actions.remove(0)\n",
        "        if col<ncols-1 and self.maze[row,col+1] == 0.0:\n",
        "            actions.remove(2)\n",
        "\n",
        "        return actions\n",
        "\n",
        "def show(qmaze):\n",
        "    \"\"\"Draw the maze on the current matplotlib axes and return the image.\n",
        "\n",
        "    Visited cells are painted 0.6, the rat 0.3 and the cheese 0.9 on a gray\n",
        "    colormap.  Grid lines are placed on the cell borders.\n",
        "    \"\"\"\n",
        "    plt.grid(True)\n",
        "    nrows, ncols = qmaze.maze.shape\n",
        "    ax = plt.gca()\n",
        "    # imshow puts columns on the x axis and rows on the y axis, so the x ticks\n",
        "    # follow ncols and the y ticks follow nrows (matters for non-square mazes).\n",
        "    ax.set_xticks(np.arange(0.5, ncols, 1))\n",
        "    ax.set_yticks(np.arange(0.5, nrows, 1))\n",
        "    ax.set_xticklabels([])\n",
        "    ax.set_yticklabels([])\n",
        "    canvas = np.copy(qmaze.maze)\n",
        "    for row,col in qmaze.visited:\n",
        "        canvas[row,col] = 0.6\n",
        "    rat_row, rat_col, _ = qmaze.state\n",
        "    canvas[rat_row, rat_col] = 0.3   # rat cell\n",
        "    canvas[nrows-1, ncols-1] = 0.9 # cheese cell\n",
        "    img = plt.imshow(canvas, interpolation='none', cmap='gray')\n",
        "    return img\n",
        "\n",
        "# 8x8 demo maze (plain nested list; Qmaze converts it with np.array)\n",
        "maze = [\n",
        "    [ 1.,  0.,  1.,  1.,  1.,  1.,  1.,  1.],\n",
        "    [ 1.,  0.,  1.,  1.,  1.,  0.,  1.,  1.],\n",
        "    [ 1.,  1.,  1.,  1.,  0.,  1.,  0.,  1.],\n",
        "    [ 1.,  1.,  1.,  0.,  1.,  1.,  1.,  1.],\n",
        "    [ 1.,  1.,  0.,  1.,  1.,  1.,  1.,  1.],\n",
        "    [ 1.,  1.,  1.,  0.,  1.,  0.,  0.,  0.],\n",
        "    [ 1.,  1.,  1.,  0.,  1.,  1.,  1.,  1.],\n",
        "    [ 1.,  1.,  1.,  1.,  0.,  1.,  1.,  1.]\n",
        "]\n",
        "\n",
        "# Smoke test: the rat starts on the default cell (0,0) and takes one step down\n",
        "qmaze = Qmaze(maze)\n",
        "# act() returns (envstate, reward, status); status is 'win', 'lose' or 'not_over'\n",
        "canvas, reward, game_over = qmaze.act(DOWN)\n",
        "print(\"reward=\", reward)\n",
        "show(qmaze)\n",
        "\n",
        "\n",
        "# Take a few more steps so the second plot shows the trail of visited cells\n",
        "qmaze.act(DOWN)  # move down\n",
        "qmaze.act(RIGHT)  # move right\n",
        "qmaze.act(RIGHT)  # move right\n",
        "qmaze.act(RIGHT)  # move right\n",
        "qmaze.act(UP)  # move up\n",
        "show(qmaze)\n",
        "\n",
        "\n",
        "def play_game(model, qmaze, rat_cell):\n",
        "    qmaze.reset(rat_cell)\n",
        "    envstate = qmaze.observe()\n",
        "    while True:\n",
        "        prev_envstate = envstate\n",
        "        # get next action\n",
        "        q = model.predict(prev_envstate)\n",
        "        action = np.argmax(q[0])\n",
        "\n",
        "        # apply action, get rewards and new state\n",
        "        envstate, reward, game_status = qmaze.act(action)\n",
        "        if game_status == 'win':\n",
        "            return True\n",
        "        elif game_status == 'lose':\n",
        "            return False\n",
        "\n",
        "\n",
        "\n",
        "def completion_check(model, qmaze):\n",
        "    \"\"\"Return True only if the model wins a greedy game from every free cell.\n",
        "\n",
        "    Stops at the first free cell that has no valid action or from which\n",
        "    play_game() does not win.\n",
        "    \"\"\"\n",
        "    return all(\n",
        "        qmaze.valid_actions(start) and play_game(model, qmaze, start)\n",
        "        for start in qmaze.free_cells\n",
        "    )\n",
        "\n",
        "\n",
        "\n",
        "class Experience(object):\n",
        "    def __init__(self, model, max_memory=100, discount=0.95):\n",
        "        self.model = model\n",
        "        self.max_memory = max_memory\n",
        "        self.discount = discount\n",
        "        self.memory = list()\n",
        "        self.num_actions = model.output_shape[-1]\n",
        "\n",
        "    def remember(self, episode):\n",
        "        # episode = [envstate, action, reward, envstate_next, game_over]\n",
        "        # memory[i] = episode\n",
        "        # envstate == flattened 1d maze cells info, including rat cell (see method: observe)\n",
        "        self.memory.append(episode)\n",
        "        if len(self.memory) > self.max_memory:\n",
        "            del self.memory[0]\n",
        "\n",
        "    def predict(self, envstate):\n",
        "        return self.model.predict(envstate)[0]\n",
        "\n",
        "    def get_data(self, data_size=10):\n",
        "        env_size = self.memory[0][0].shape[1]   # envstate 1d size (1st element of episode)\n",
        "        mem_size = len(self.memory)\n",
        "        data_size = min(mem_size, data_size)\n",
        "        inputs = np.zeros((data_size, env_size))\n",
        "        targets = np.zeros((data_size, self.num_actions))\n",
        "        for i, j in enumerate(np.random.choice(range(mem_size), data_size, replace=False)):\n",
        "            envstate, action, reward, envstate_next, game_over = self.memory[j]\n",
        "            inputs[i] = envstate\n",
        "            # There should be no target values for actions not taken.\n",
        "            targets[i] = self.predict(envstate)\n",
        "            #print(targets[i])\n",
        "            # Q_sa = derived policy = max quality env/action = max_a' Q(s', a')\n",
        "            Q_sa = np.max(self.predict(envstate_next))\n",
        "            if game_over:\n",
        "                targets[i, action] = reward\n",
        "            else:\n",
        "                # reward + gamma * max_a' Q(s', a')\n",
        "                targets[i, action] = reward + self.discount * Q_sa\n",
        "        return inputs, targets\n",
        "def qtrain(model, maze, **opt):\n",
        "    \"\"\"Train `model` on `maze` with Q-learning and experience replay.\n",
        "\n",
        "    Each epoch plays one game from a random free cell with an epsilon-greedy\n",
        "    policy; after every move a batch sampled from replay memory is fitted.\n",
        "    Training stops early once the last `hsize` finished games were all won and\n",
        "    completion_check() passes.\n",
        "\n",
        "    Keyword options:\n",
        "        n_epoch: maximum number of games (default 15000).  `epochs` is accepted\n",
        "            as an alias because this notebook calls qtrain(..., epochs=...).\n",
        "        max_memory: replay memory capacity (default 1000).\n",
        "        data_size: number of stored moves sampled per training step (default 50).\n",
        "        weights_file: h5 file to resume training from (default \"\", i.e. none).\n",
        "        name: base name of the saved weights/architecture files (default 'model').\n",
        "        stats_dir: directory for the wrdb/edb/epdb .npy histories written when\n",
        "            the 100% win rate is reached (default: the notebook's Drive folder).\n",
        "\n",
        "    Returns:\n",
        "        Elapsed training time in seconds.\n",
        "\n",
        "    Side effects:\n",
        "        Prints one progress line per epoch, writes <name>.h5 and <name>.json,\n",
        "        and lowers the module-level `epsilon` to 0.05 once the win rate\n",
        "        exceeds 0.9 (the change persists after the call).\n",
        "    \"\"\"\n",
        "    global epsilon\n",
        "    n_epoch = opt.get('n_epoch', opt.get('epochs', 15000))\n",
        "    max_memory = opt.get('max_memory', 1000)\n",
        "    data_size = opt.get('data_size', 50)\n",
        "    weights_file = opt.get('weights_file', \"\")\n",
        "    name = opt.get('name', 'model')\n",
        "    stats_dir = opt.get('stats_dir', '/content/drive/My Drive/db/dqn/maze/dqn')\n",
        "    start_time = datetime.datetime.now()\n",
        "\n",
        "    # If you want to continue training from a previous model,\n",
        "    # just supply the h5 file name to weights_file option\n",
        "    if weights_file:\n",
        "        print(\"loading weights from file: %s\" % (weights_file,))\n",
        "        model.load_weights(weights_file)\n",
        "\n",
        "    # Construct environment/game from numpy array: maze (see above)\n",
        "    qmaze = Qmaze(maze)\n",
        "\n",
        "    # Initialize experience replay object\n",
        "    experience = Experience(model, max_memory=max_memory)\n",
        "\n",
        "    win_history = []   # history of win/lose game\n",
        "    hsize = qmaze.maze.size//2   # history window size\n",
        "    win_rate = 0.0\n",
        "    wrdb=[]   # cumulative win count per epoch\n",
        "    edb=[]    # epoch index\n",
        "    epdb=[]   # number of moves per epoch\n",
        "    epoch = 0  # keeps the final report valid even when n_epoch is 0\n",
        "    for epoch in range(n_epoch):\n",
        "        loss = 0.0\n",
        "        rat_cell = random.choice(qmaze.free_cells)\n",
        "        qmaze.reset(rat_cell)\n",
        "        game_over = False\n",
        "\n",
        "        # get initial envstate (1d flattened canvas)\n",
        "        envstate = qmaze.observe()\n",
        "\n",
        "        n_episodes = 0\n",
        "        while not game_over:\n",
        "            valid_actions = qmaze.valid_actions()\n",
        "            if not valid_actions:\n",
        "                break\n",
        "            prev_envstate = envstate\n",
        "            # Get next action: explore with probability epsilon, else exploit\n",
        "            if np.random.rand() < epsilon:\n",
        "                action = random.choice(valid_actions)\n",
        "            else:\n",
        "                action = np.argmax(experience.predict(prev_envstate))\n",
        "\n",
        "            # Apply action, get reward and new envstate\n",
        "            envstate, reward, game_status = qmaze.act(action)\n",
        "            if game_status == 'win':\n",
        "                win_history.append(1)\n",
        "                game_over = True\n",
        "            elif game_status == 'lose':\n",
        "                win_history.append(0)\n",
        "                game_over = True\n",
        "            else:\n",
        "                game_over = False\n",
        "\n",
        "            # Store episode (experience)\n",
        "            episode = [prev_envstate, action, reward, envstate, game_over]\n",
        "            experience.remember(episode)\n",
        "            n_episodes += 1\n",
        "\n",
        "            # Train neural network model\n",
        "            inputs, targets = experience.get_data(data_size=data_size)\n",
        "            model.fit(\n",
        "                inputs,\n",
        "                targets,\n",
        "                epochs=8,\n",
        "                batch_size=16,\n",
        "                verbose=0,\n",
        "            )\n",
        "            loss = model.evaluate(inputs, targets, verbose=0)\n",
        "\n",
        "        # The window is complete as soon as hsize games have finished\n",
        "        # (the former '>' test started reporting one game late).\n",
        "        if len(win_history) >= hsize:\n",
        "            win_rate = sum(win_history[-hsize:]) / hsize\n",
        "\n",
        "        dt = datetime.datetime.now() - start_time\n",
        "        t = format_time(dt.total_seconds())\n",
        "        template = \"Epoch: {:03d}/{:d} | Loss: {:.4f} | Episodes: {:d} | Win count: {:d} | Win rate: {:.3f} | time: {}\"\n",
        "        print(template.format(epoch, n_epoch-1, loss, n_episodes, sum(win_history), win_rate, t))\n",
        "        wrdb.append(sum(win_history))\n",
        "        edb.append(epoch)\n",
        "        epdb.append(n_episodes)\n",
        "        # Explore less once the agent wins most of its games\n",
        "        if win_rate > 0.9:\n",
        "            epsilon = 0.05\n",
        "        # we simply check if training has exhausted all free cells and if in all\n",
        "        # cases the agent won\n",
        "        if sum(win_history[-hsize:]) == hsize and completion_check(model, qmaze):\n",
        "            print(\"Reached 100%% win rate at epoch: %d\" % (epoch,))\n",
        "\n",
        "            # Create the folder if needed so a missing directory cannot abort the\n",
        "            # run before the trained weights are saved below.\n",
        "            os.makedirs(stats_dir, exist_ok=True)\n",
        "            np.save(os.path.join(stats_dir, 'wrdb'), wrdb)\n",
        "            np.save(os.path.join(stats_dir, 'edb'), edb)\n",
        "            np.save(os.path.join(stats_dir, 'epdb'), epdb)\n",
        "            break\n",
        "\n",
        "    # Save trained model weights and architecture, this will be used by the visualization code\n",
        "    h5file = name + \".h5\"\n",
        "    json_file = name + \".json\"\n",
        "    model.save_weights(h5file, overwrite=True)\n",
        "    with open(json_file, \"w\") as outfile:\n",
        "        json.dump(model.to_json(), outfile)\n",
        "    dt = datetime.datetime.now() - start_time\n",
        "    seconds = dt.total_seconds()\n",
        "    t = format_time(seconds)\n",
        "    print('files: %s, %s' % (h5file, json_file))\n",
        "    print(\"n_epoch: %d, max_mem: %d, data: %d, time: %s\" % (epoch, max_memory, data_size, t))\n",
        "    return seconds\n",
        "\n",
        "\n",
        "# This is a small utility for printing readable time strings:\n",
        "def format_time(seconds):\n",
        "    if seconds < 400:\n",
        "        s = float(seconds)\n",
        "        return \"%.1f seconds\" % (s,)\n",
        "    elif seconds < 4000:\n",
        "        m = seconds / 60.0\n",
        "        return \"%.2f minutes\" % (m,)\n",
        "    else:\n",
        "        h = seconds / 3600.0\n",
        "        return \"%.2f hours\" % (h,)\n",
        "\n",
        "\n",
        "def build_model(maze, lr=0.001):\n",
        "    \"\"\"Build and compile the Q-network for `maze`.\n",
        "\n",
        "    The network maps a flattened maze (maze.size inputs) through two PReLU\n",
        "    hidden layers of the same width to one linear Q-value per action.\n",
        "\n",
        "    Args:\n",
        "        maze: numpy array; only its `size` is used.\n",
        "        lr: learning rate of the Adam optimizer.  The default matches Adam's\n",
        "            own default, so build_model(maze) trains exactly as before; the\n",
        "            argument used to be accepted but ignored.\n",
        "\n",
        "    Returns:\n",
        "        The compiled Sequential model (mean squared error loss).\n",
        "    \"\"\"\n",
        "    model = Sequential()\n",
        "    model.add(Dense(maze.size, input_shape=(maze.size,)))\n",
        "    model.add(PReLU())\n",
        "    model.add(Dense(maze.size))\n",
        "    model.add(PReLU())\n",
        "    model.add(Dense(num_actions))\n",
        "    model.compile(optimizer=Adam(learning_rate=lr), loss='mse')\n",
        "    return model\n",
        "\n",
        "# 7x7 training maze: 1.0 = free cell, 0.0 = blocked cell\n",
        "maze =  np.array([\n",
        "    [ 1.,  0.,  1.,  1.,  1.,  1.,  1.],\n",
        "    [ 1.,  1.,  1.,  0.,  0.,  1.,  0.],\n",
        "    [ 0.,  0.,  0.,  1.,  1.,  1.,  0.],\n",
        "    [ 1.,  1.,  1.,  1.,  0.,  0.,  1.],\n",
        "    [ 1.,  0.,  0.,  0.,  1.,  1.,  1.],\n",
        "    [ 1.,  0.,  1.,  1.,  1.,  1.,  1.],\n",
        "    [ 1.,  1.,  1.,  0.,  1.,  1.,  1.]\n",
        "])\n",
        "# Google Drive is mounted because qtrain saves its history arrays under /content/drive\n",
        "from google.colab import drive\n",
        "drive.mount('/content/drive')\n",
        "\n",
        "qmaze = Qmaze(maze)\n",
        "show(qmaze)\n",
        "\n",
        "model = build_model(maze)\n",
        "# qtrain reads the epoch cap from `n_epoch`; the former `epochs=1000` was not a\n",
        "# recognised option name, so training silently used the 15000-epoch default.\n",
        "qtrain(model, maze, n_epoch=1000, max_memory=8*maze.size, data_size=32)"
      ],
      "execution_count": 15,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "reward= -0.04\n",
            "Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount(\"/content/drive\", force_remount=True).\n",
            "Epoch: 000/14999 | Loss: 0.0024 | Episodes: 59 | Win count: 1 | Win rate: 0.000 | time: 3.4 seconds\n",
            "Epoch: 001/14999 | Loss: 0.0030 | Episodes: 106 | Win count: 1 | Win rate: 0.000 | time: 9.6 seconds\n",
            "Epoch: 002/14999 | Loss: 0.0048 | Episodes: 4 | Win count: 2 | Win rate: 0.000 | time: 9.9 seconds\n",
            "Epoch: 003/14999 | Loss: 0.0018 | Episodes: 4 | Win count: 3 | Win rate: 0.000 | time: 10.1 seconds\n",
            "Epoch: 004/14999 | Loss: 0.0020 | Episodes: 61 | Win count: 4 | Win rate: 0.000 | time: 13.6 seconds\n",
            "Epoch: 005/14999 | Loss: 0.0040 | Episodes: 5 | Win count: 5 | Win rate: 0.000 | time: 13.9 seconds\n",
            "Epoch: 006/14999 | Loss: 0.0030 | Episodes: 12 | Win count: 6 | Win rate: 0.000 | time: 14.7 seconds\n",
            "Epoch: 007/14999 | Loss: 0.0024 | Episodes: 2 | Win count: 7 | Win rate: 0.000 | time: 14.8 seconds\n",
            "Epoch: 008/14999 | Loss: 0.0036 | Episodes: 6 | Win count: 8 | Win rate: 0.000 | time: 15.2 seconds\n",
            "Epoch: 009/14999 | Loss: 0.0069 | Episodes: 8 | Win count: 9 | Win rate: 0.000 | time: 15.7 seconds\n",
            "Epoch: 010/14999 | Loss: 0.0115 | Episodes: 3 | Win count: 10 | Win rate: 0.000 | time: 15.8 seconds\n",
            "Epoch: 011/14999 | Loss: 0.0027 | Episodes: 3 | Win count: 11 | Win rate: 0.000 | time: 16.0 seconds\n",
            "Epoch: 012/14999 | Loss: 0.0028 | Episodes: 28 | Win count: 12 | Win rate: 0.000 | time: 17.7 seconds\n",
            "Epoch: 013/14999 | Loss: 0.0031 | Episodes: 102 | Win count: 12 | Win rate: 0.000 | time: 23.8 seconds\n",
            "Epoch: 014/14999 | Loss: 0.0080 | Episodes: 1 | Win count: 13 | Win rate: 0.000 | time: 23.9 seconds\n",
            "Epoch: 015/14999 | Loss: 0.0044 | Episodes: 74 | Win count: 14 | Win rate: 0.000 | time: 28.2 seconds\n",
            "Epoch: 016/14999 | Loss: 0.0058 | Episodes: 2 | Win count: 15 | Win rate: 0.000 | time: 28.3 seconds\n",
            "Epoch: 017/14999 | Loss: 0.0016 | Episodes: 14 | Win count: 16 | Win rate: 0.000 | time: 29.1 seconds\n",
            "Epoch: 018/14999 | Loss: 0.0014 | Episodes: 5 | Win count: 17 | Win rate: 0.000 | time: 29.4 seconds\n",
            "Epoch: 019/14999 | Loss: 0.0035 | Episodes: 100 | Win count: 17 | Win rate: 0.000 | time: 35.5 seconds\n",
            "Epoch: 020/14999 | Loss: 0.0019 | Episodes: 3 | Win count: 18 | Win rate: 0.000 | time: 35.7 seconds\n",
            "Epoch: 021/14999 | Loss: 0.0017 | Episodes: 106 | Win count: 18 | Win rate: 0.000 | time: 42.4 seconds\n",
            "Epoch: 022/14999 | Loss: 0.0125 | Episodes: 5 | Win count: 19 | Win rate: 0.000 | time: 42.8 seconds\n",
            "Epoch: 023/14999 | Loss: 0.0023 | Episodes: 104 | Win count: 19 | Win rate: 0.000 | time: 48.9 seconds\n",
            "Epoch: 024/14999 | Loss: 0.0030 | Episodes: 106 | Win count: 19 | Win rate: 0.750 | time: 55.1 seconds\n",
            "Epoch: 025/14999 | Loss: 0.0039 | Episodes: 5 | Win count: 20 | Win rate: 0.792 | time: 55.4 seconds\n",
            "Epoch: 026/14999 | Loss: 0.0359 | Episodes: 5 | Win count: 21 | Win rate: 0.792 | time: 55.7 seconds\n",
            "Epoch: 027/14999 | Loss: 0.0028 | Episodes: 104 | Win count: 21 | Win rate: 0.750 | time: 61.6 seconds\n",
            "Epoch: 028/14999 | Loss: 0.0014 | Episodes: 5 | Win count: 22 | Win rate: 0.750 | time: 61.9 seconds\n",
            "Epoch: 029/14999 | Loss: 0.0028 | Episodes: 104 | Win count: 22 | Win rate: 0.708 | time: 67.9 seconds\n",
            "Epoch: 030/14999 | Loss: 0.0634 | Episodes: 107 | Win count: 22 | Win rate: 0.667 | time: 74.0 seconds\n",
            "Epoch: 031/14999 | Loss: 0.0472 | Episodes: 104 | Win count: 22 | Win rate: 0.625 | time: 80.2 seconds\n",
            "Epoch: 032/14999 | Loss: 0.0023 | Episodes: 34 | Win count: 23 | Win rate: 0.625 | time: 82.3 seconds\n",
            "Epoch: 033/14999 | Loss: 0.0011 | Episodes: 107 | Win count: 23 | Win rate: 0.583 | time: 88.5 seconds\n",
            "Epoch: 034/14999 | Loss: 0.0031 | Episodes: 10 | Win count: 24 | Win rate: 0.583 | time: 89.1 seconds\n",
            "Epoch: 035/14999 | Loss: 0.0208 | Episodes: 4 | Win count: 25 | Win rate: 0.583 | time: 89.3 seconds\n",
            "Epoch: 036/14999 | Loss: 0.0480 | Episodes: 2 | Win count: 26 | Win rate: 0.583 | time: 89.4 seconds\n",
            "Epoch: 037/14999 | Loss: 0.0185 | Episodes: 4 | Win count: 27 | Win rate: 0.625 | time: 89.7 seconds\n",
            "Epoch: 038/14999 | Loss: 0.0040 | Episodes: 107 | Win count: 27 | Win rate: 0.583 | time: 96.0 seconds\n",
            "Epoch: 039/14999 | Loss: 0.0053 | Episodes: 10 | Win count: 28 | Win rate: 0.583 | time: 96.6 seconds\n",
            "Epoch: 040/14999 | Loss: 0.0378 | Episodes: 8 | Win count: 29 | Win rate: 0.583 | time: 97.1 seconds\n",
            "Epoch: 041/14999 | Loss: 0.0028 | Episodes: 4 | Win count: 30 | Win rate: 0.583 | time: 97.4 seconds\n",
            "Epoch: 042/14999 | Loss: 0.0035 | Episodes: 1 | Win count: 31 | Win rate: 0.583 | time: 97.4 seconds\n",
            "Epoch: 043/14999 | Loss: 0.0030 | Episodes: 5 | Win count: 32 | Win rate: 0.625 | time: 97.7 seconds\n",
            "Epoch: 044/14999 | Loss: 0.0018 | Episodes: 106 | Win count: 32 | Win rate: 0.583 | time: 104.0 seconds\n",
            "Epoch: 045/14999 | Loss: 0.0010 | Episodes: 56 | Win count: 33 | Win rate: 0.625 | time: 108.5 seconds\n",
            "Epoch: 046/14999 | Loss: 0.0067 | Episodes: 16 | Win count: 34 | Win rate: 0.625 | time: 109.8 seconds\n",
            "Epoch: 047/14999 | Loss: 0.0064 | Episodes: 1 | Win count: 35 | Win rate: 0.667 | time: 109.9 seconds\n",
            "Epoch: 048/14999 | Loss: 0.0087 | Episodes: 9 | Win count: 36 | Win rate: 0.708 | time: 110.8 seconds\n",
            "Epoch: 049/14999 | Loss: 0.0028 | Episodes: 2 | Win count: 37 | Win rate: 0.708 | time: 111.0 seconds\n",
            "Epoch: 050/14999 | Loss: 0.0063 | Episodes: 6 | Win count: 38 | Win rate: 0.708 | time: 111.6 seconds\n",
            "Epoch: 051/14999 | Loss: 0.0029 | Episodes: 106 | Win count: 38 | Win rate: 0.708 | time: 118.6 seconds\n",
            "Epoch: 052/14999 | Loss: 0.0019 | Episodes: 101 | Win count: 38 | Win rate: 0.667 | time: 124.5 seconds\n",
            "Epoch: 053/14999 | Loss: 0.0670 | Episodes: 7 | Win count: 39 | Win rate: 0.708 | time: 125.0 seconds\n",
            "Epoch: 054/14999 | Loss: 0.0033 | Episodes: 104 | Win count: 39 | Win rate: 0.708 | time: 131.1 seconds\n",
            "Epoch: 055/14999 | Loss: 0.0034 | Episodes: 4 | Win count: 40 | Win rate: 0.750 | time: 131.4 seconds\n",
            "Epoch: 056/14999 | Loss: 0.0571 | Episodes: 16 | Win count: 41 | Win rate: 0.750 | time: 132.4 seconds\n",
            "Epoch: 057/14999 | Loss: 0.0536 | Episodes: 5 | Win count: 42 | Win rate: 0.792 | time: 132.7 seconds\n",
            "Epoch: 058/14999 | Loss: 0.0036 | Episodes: 5 | Win count: 43 | Win rate: 0.792 | time: 133.0 seconds\n",
            "Epoch: 059/14999 | Loss: 0.0105 | Episodes: 13 | Win count: 44 | Win rate: 0.792 | time: 133.8 seconds\n",
            "Epoch: 060/14999 | Loss: 0.0045 | Episodes: 9 | Win count: 45 | Win rate: 0.792 | time: 134.4 seconds\n",
            "Epoch: 061/14999 | Loss: 0.0027 | Episodes: 18 | Win count: 46 | Win rate: 0.792 | time: 135.5 seconds\n",
            "Epoch: 062/14999 | Loss: 0.0233 | Episodes: 22 | Win count: 47 | Win rate: 0.833 | time: 136.8 seconds\n",
            "Epoch: 063/14999 | Loss: 0.0086 | Episodes: 12 | Win count: 48 | Win rate: 0.833 | time: 137.5 seconds\n",
            "Epoch: 064/14999 | Loss: 0.0021 | Episodes: 7 | Win count: 49 | Win rate: 0.833 | time: 138.0 seconds\n",
            "Epoch: 065/14999 | Loss: 0.0047 | Episodes: 59 | Win count: 50 | Win rate: 0.833 | time: 141.5 seconds\n",
            "Epoch: 066/14999 | Loss: 0.0030 | Episodes: 2 | Win count: 51 | Win rate: 0.833 | time: 141.6 seconds\n",
            "Epoch: 067/14999 | Loss: 0.0064 | Episodes: 28 | Win count: 52 | Win rate: 0.833 | time: 143.3 seconds\n",
            "Epoch: 068/14999 | Loss: 0.0210 | Episodes: 25 | Win count: 53 | Win rate: 0.875 | time: 144.9 seconds\n",
            "Epoch: 069/14999 | Loss: 0.0018 | Episodes: 10 | Win count: 54 | Win rate: 0.875 | time: 145.5 seconds\n",
            "Epoch: 070/14999 | Loss: 0.0053 | Episodes: 16 | Win count: 55 | Win rate: 0.875 | time: 146.5 seconds\n",
            "Epoch: 071/14999 | Loss: 0.0016 | Episodes: 11 | Win count: 56 | Win rate: 0.875 | time: 147.1 seconds\n",
            "Epoch: 072/14999 | Loss: 0.0007 | Episodes: 23 | Win count: 57 | Win rate: 0.875 | time: 148.6 seconds\n",
            "Epoch: 073/14999 | Loss: 0.0017 | Episodes: 11 | Win count: 58 | Win rate: 0.875 | time: 149.4 seconds\n",
            "Epoch: 074/14999 | Loss: 0.0012 | Episodes: 31 | Win count: 59 | Win rate: 0.875 | time: 151.2 seconds\n",
            "Epoch: 075/14999 | Loss: 0.0014 | Episodes: 34 | Win count: 60 | Win rate: 0.917 | time: 153.1 seconds\n",
            "Epoch: 076/14999 | Loss: 0.0010 | Episodes: 1 | Win count: 61 | Win rate: 0.958 | time: 153.2 seconds\n",
            "Epoch: 077/14999 | Loss: 0.0018 | Episodes: 15 | Win count: 62 | Win rate: 0.958 | time: 154.1 seconds\n",
            "Epoch: 078/14999 | Loss: 0.0035 | Episodes: 23 | Win count: 63 | Win rate: 1.000 | time: 155.5 seconds\n",
            "Epoch: 079/14999 | Loss: 0.0017 | Episodes: 2 | Win count: 64 | Win rate: 1.000 | time: 155.7 seconds\n",
            "Epoch: 080/14999 | Loss: 0.0006 | Episodes: 46 | Win count: 65 | Win rate: 1.000 | time: 158.6 seconds\n",
            "Epoch: 081/14999 | Loss: 0.0012 | Episodes: 2 | Win count: 66 | Win rate: 1.000 | time: 158.8 seconds\n",
            "Epoch: 082/14999 | Loss: 0.0012 | Episodes: 28 | Win count: 67 | Win rate: 1.000 | time: 160.5 seconds\n",
            "Epoch: 083/14999 | Loss: 0.0017 | Episodes: 21 | Win count: 68 | Win rate: 1.000 | time: 161.9 seconds\n",
            "Epoch: 084/14999 | Loss: 0.0021 | Episodes: 29 | Win count: 69 | Win rate: 1.000 | time: 163.7 seconds\n",
            "Epoch: 085/14999 | Loss: 0.0016 | Episodes: 1 | Win count: 70 | Win rate: 1.000 | time: 163.9 seconds\n",
            "Reached 100% win rate at epoch: 85\n",
            "files: model.h5, model.json\n",
            "n_epoch: 85, max_mem: 392, data: 32, time: 164.8 seconds\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "164.819645"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 15
        },
        {
          "output_type": "display_data",
          "data": {
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAOsAAADrCAYAAACICmHVAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAFR0lEQVR4nO3dMW5TaRiF4f+OEEiGEc1It0mJZHq7RTKrYAes4LbswNRIrCA9C4gXEBeU6SiQUKSUof6nmClmRCCxCPk4uc8juQroXIhfiKtv6L034Pf3R/UDADcjVgghVgghVgghVgghVgjx4JBf/PDhw75YLH7Vs/zQYrFoX758Kdl+/vx5e/z4ccn2169fbc9o+9OnT+3i4mK46msHxbpYLNqLFy9u56kOtNls2jRNJdvv3r1rm82mZHu329me0fZ6vf7u1/wYDCHECiHECiHECiHECiHECiHECiHECiHECiHECiHECiHECiHECiHECiHECiHECiHECiHECiHECiHECiHECiHECiHECiHECiHECiHECiHECiHECiEOOkz17Nmz9uHDh1/1LD+02+1a771su8p+v28vX74s2d5ut6XbVcehWmttGK485FZquC6AYRhet9Zet9baOI6r4+Pju3iub1xeXrYnT57Mbvv8/Lx9/vy5ZPvo6Kh0exzHku3Ly8t2dnZWsj1NU+u9X/0vRe/9xq/VatWrnJyczHJ7u9321lrJq3q7ysnJSdmf+58kr+7PZ1YIIVYIIVYIIVYIIVYIIVYIIVYIIVYIIVYIIVYIIVYIIVYIIVYIIVYIIVYIIVYIIVYIIVYIIVYIIVYIIVYIIVYIIVYIIVYIIVYIIVYIIVYIIVYIcVCs+/2+DcNQ8prr9mq1Ouh42G2+qrf5v4NOPj59+nT15s2bu3iub1SfH6zaXi6Xszx1Wb0df/KxFZ7Bqz4/WLU911OX1duV7/Xu5CNkEyuEECuEECuEECuEECuEECuEECuEECuEECuEECuEECuEECuEECuEECuEECuEECuEECuEECuEECuEECuEECuEECuEECuEECuEECuEECuEECuEOCjW6hOAc9yuNsczm/v9vvS99t3vxXVviP+efBzHcXV8fHyrb4abqj4BONftqtOH1Sc+x3Es2Z6mqZ2env78ycfVatWrVJ8AnOt2m+GZze12W/Z3/m9jV/bnMyuEECuEECuEECuEECuEECuEECuEECuEECuEECuEECuEECuEECuEECuEECuEECuEECuEECuEECuEECuEECuEECuEECuEECuEECuEECuEECuEECuEECuEiDn5eH5+XnoCcK7bVacPq09dVm3fi5OP1ScA57pdpfrUZRUnH+EeECuEECuEECuEECuEECuEECuEECuEECuEECuEECuEECuEECuEECuEECuEECuEECuEECuEECuEECuEECuEECuEECuEECuEECuEECuEECuEECuEECuEcPLxBpbL5SzPD9q+e04+/uRrrucHbd89Jx/hHhArhBArhBArhBArhBArhBArhBArhBArhBArhBArhBArhBArhBArhBArhBArhBArhBArhBArhBArhBArhBArhBArhBArhBArhBArhBArhBArhBArhIg5+TjXE4CVpy6Pjo7aOI4l29Xf70ePHpVsT9PUPn78eOXJxwfX/ebe+/vW2vvWWluv132z2dzu093Qbrdrc9x++/Ztm6apZHu73bZXr16VbFd/v5fLZcn2j/gxGEKIFUKIFUKIFUKIFUKIFUKIFUKIFUKIFUKIFUKIFUKIFUKIFUKIFUKIFUKIFUKIFUKIFUKIFUKIFUKIFUKIFUKIFUKIFUKIFUKIFUKIFUKIFUKIFUIcdPKxtbZsrZ396of6jr9aaxe2bd/z7WXv/c+rvnBtrL+LYRhOe+9r27bnuu3HYAghVgiRFOt727bnvB3zmRXmLul/Vpg1sU
IIsUIIsUIIsUKIvwFeHJLQ+CueIQAAAABJRU5ErkJggg==\n",
            "text/plain": [
              "<Figure size 432x288 with 1 Axes>"
            ]
          },
          "metadata": {
            "tags": [],
            "needs_background": "light"
          }
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "XUAWVTZtKfnK",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 248
        },
        "outputId": "5b650add-b6a1-4156-a696-b9d6abfa9a67"
      },
      "source": [
        ""
      ],
      "execution_count": 16,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount(\"/content/drive\", force_remount=True).\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "error",
          "ename": "NameError",
          "evalue": "ignored",
          "traceback": [
            "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
            "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
            "\u001b[0;32m<ipython-input-16-a7ffd00781e6>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m      2\u001b[0m \u001b[0mdrive\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmount\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'/content/drive'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      3\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msave\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'/content/drive/My Drive/db/dqn/maze/dqn/wrdb'\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mwrdb\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      5\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msave\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'/content/drive/My Drive/db/dqn/maze/dqn/edb'\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0medb\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      6\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msave\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'/content/drive/My Drive/db/dqn/maze/dqn/epdb'\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mepdb\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;31mNameError\u001b[0m: name 'wrdb' is not defined"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "usl0p3vMOacR",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "\"\"\"\n",
        "Reached 100% win rate at epoch: 85\n",
        "files: model.h5, model.json\n",
        "n_epoch: 85, max_mem: 392, data: 32, time: 164.8 seconds\n",
        "164.819645\n",
        "\"\"\""
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}